{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Counterfactual explanations with ordinally encoded categorical variables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example notebook illustrates how to obtain [counterfactual explanations](https://docs.seldon.io/projects/alibi/en/stable/methods/CFProto.html) for instances with a mixture of ordinally encoded categorical and numerical variables. A more elaborate notebook highlighting additional functionality can be found [here](./cfproto_cat_adult_ohe.ipynb). We generate counterfactuals for instances in the *adult* dataset where we predict whether a person's income is above or below $50k."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "Note\n",
    "    \n",
    "To enable support for CounterfactualProto, you may need to run\n",
    "    \n",
    "```bash\n",
    "pip install alibi[tensorflow]\n",
    "```\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TF version:  2.2.0\n",
      "Eager execution enabled:  False\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\"\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel(40) # suppress deprecation messages\n",
    "tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs\n",
    "from tensorflow.keras.layers import Dense, Input, Embedding, Concatenate, Reshape, Dropout, Lambda\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from time import time\n",
    "from alibi.datasets import fetch_adult\n",
    "from alibi.explainers import CounterfactualProto\n",
    "\n",
    "print('TF version: ', tf.__version__)\n",
    "print('Eager execution enabled: ', tf.executing_eagerly()) # False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load adult dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `fetch_adult` function returns a `Bunch` object containing the features, the targets, the feature names and a mapping of the categories in each categorical variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult = fetch_adult()\n",
    "data = adult.data\n",
    "target = adult.target\n",
    "feature_names = adult.feature_names\n",
    "category_map_tmp = adult.category_map\n",
    "target_names = adult.target_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Shuffle the data and define the training and test sets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(s=0):\n",
    "    np.random.seed(s)\n",
    "    tf.random.set_seed(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed()\n",
    "data_perm = np.random.permutation(np.c_[data, target])\n",
    "X = data_perm[:,:-1]\n",
    "y = data_perm[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 30000\n",
    "y_train, y_test = y[:idx], y[idx+1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reorganize data so categorical features come first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.c_[X[:, 1:8], X[:, 11], X[:, 0], X[:, 8:11]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adjust `feature_names` and `category_map` as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Workclass', 'Education', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Country', 'Age', 'Capital Gain', 'Capital Loss', 'Hours per week']\n"
     ]
    }
   ],
   "source": [
    "feature_names = feature_names[1:8] + feature_names[11:12] + feature_names[0:1] + feature_names[8:11]\n",
    "print(feature_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "category_map = {}\n",
    "for i, (_, v) in enumerate(category_map_tmp.items()):\n",
    "    category_map[i] = v"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a dictionary whose keys are the indices of the categorical columns and whose values are the number of categories for each of those variables. This dictionary is passed to the counterfactual explainer later on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 9, 1: 7, 2: 4, 3: 9, 4: 6, 5: 5, 6: 2, 7: 11}\n"
     ]
    }
   ],
   "source": [
    "cat_vars_ord = {}\n",
    "n_categories = len(list(category_map.keys()))\n",
    "for i in range(n_categories):\n",
    "    cat_vars_ord[i] = len(np.unique(X[:, i]))\n",
    "print(cat_vars_ord)"
   ]
  },
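  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the same idea on made-up data (the array `X_toy` and the names `n_cat_cols` and `cat_vars_toy` are illustrative assumptions, not part of the adult dataset), the dictionary maps each categorical column index to its number of distinct values. Note that counting unique values only gives the true number of categories if every category occurs at least once in the data.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy data: two ordinally encoded categorical columns, then one numerical column.\n",
    "X_toy = np.array([[0, 2, 0.5],\n",
    "                  [1, 0, -0.3],\n",
    "                  [1, 1, 0.9],\n",
    "                  [0, 2, 0.1]])\n",
    "\n",
    "n_cat_cols = 2  # assumption: the categorical columns come first\n",
    "\n",
    "# Column index -> number of distinct categories observed in that column.\n",
    "cat_vars_toy = {i: len(np.unique(X_toy[:, i])) for i in range(n_cat_cols)}\n",
    "print(cat_vars_toy)  # {0: 2, 1: 3}\n",
    "```"
   ]
  },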
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scale numerical features between -1 and 1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_num = X[:, -4:].astype(np.float32, copy=False)\n",
    "xmin, xmax = X_num.min(axis=0), X_num.max(axis=0)\n",
    "rng = (-1., 1.)\n",
    "X_num_scaled = (X_num - xmin) / (xmax - xmin) * (rng[1] - rng[0]) + rng[0]\n",
    "X_num_scaled_train = X_num_scaled[:idx, :]\n",
    "X_num_scaled_test = X_num_scaled[idx+1:, :]"
   ]
  },
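  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaling above is a plain min-max transformation, so it can be inverted to read a scaled value back in its original units. A minimal sketch, assuming hypothetical helper names `scale` and `unscale` and a made-up `ages` array:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def scale(x, xmin, xmax, rng=(-1., 1.)):\n",
    "    # Map values from [xmin, xmax] to [rng[0], rng[1]].\n",
    "    return (x - xmin) / (xmax - xmin) * (rng[1] - rng[0]) + rng[0]\n",
    "\n",
    "def unscale(x_scaled, xmin, xmax, rng=(-1., 1.)):\n",
    "    # Inverse of scale: map values from [rng[0], rng[1]] back to [xmin, xmax].\n",
    "    return (x_scaled - rng[0]) / (rng[1] - rng[0]) * (xmax - xmin) + xmin\n",
    "\n",
    "ages = np.array([17., 40., 90.])  # illustrative values only\n",
    "scaled = scale(ages, ages.min(), ages.max())\n",
    "restored = unscale(scaled, ages.min(), ages.max())\n",
    "\n",
    "print(scaled)    # smallest value maps to -1, largest to 1\n",
    "print(restored)  # recovers the original values\n",
    "```\n",
    "\n",
    "The same inverse can be applied to the numerical features of a counterfactual when the perturbation should be reported in original units rather than on the scaled range."
   ]
  },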
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Combine numerical and categorical data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30000, 12) (2560, 12)\n"
     ]
    }
   ],
   "source": [
    "X = np.c_[X[:, :-4], X_num_scaled].astype(np.float32, copy=False)\n",
    "X_train, X_test = X[:idx, :], X[idx+1:, :]\n",
    "print(X_train.shape, X_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a neural net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The neural net will use entity embeddings for the categorical variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nn_ord():\n",
    "    \n",
    "    x_in = Input(shape=(12,))\n",
    "    layers_in = []\n",
    "    \n",
    "    # embedding layers\n",
    "    for i, (_, v) in enumerate(cat_vars_ord.items()):\n",
    "        # bind i as a default argument so each Lambda keeps its own column index\n",
    "        emb_in = Lambda(lambda x, i=i: x[:, i:i+1])(x_in)\n",
    "        emb_dim = int(max(min(np.ceil(.5 * v), 50), 2))\n",
    "        emb_layer = Embedding(input_dim=v+1, output_dim=emb_dim, input_length=1)(emb_in)\n",
    "        emb_layer = Reshape(target_shape=(emb_dim,))(emb_layer)\n",
    "        layers_in.append(emb_layer)\n",
    "        \n",
    "    # numerical layers\n",
    "    num_in = Lambda(lambda x: x[:, -4:])(x_in)\n",
    "    num_layer = Dense(16)(num_in)\n",
    "    layers_in.append(num_layer)\n",
    "    \n",
    "    # combine\n",
    "    x = Concatenate()(layers_in)\n",
    "    x = Dense(60, activation='relu')(x)\n",
    "    x = Dropout(.2)(x)\n",
    "    x = Dense(60, activation='relu')(x)\n",
    "    x = Dropout(.2)(x)\n",
    "    x = Dense(60, activation='relu')(x)\n",
    "    x = Dropout(.2)(x)\n",
    "    x_out = Dense(2, activation='softmax')(x)\n",
    "    \n",
    "    nn = Model(inputs=x_in, outputs=x_out)\n",
    "    nn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "    \n",
    "    return nn"
   ]
  },
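  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The embedding size rule in `nn_ord` (half the number of categories, rounded up, then clipped to the range 2 to 50) can be checked in isolation. The sketch below uses a hypothetical helper `emb_dim` together with the category counts printed earlier for `cat_vars_ord`; the resulting sizes should match the embedding output shapes in the model summary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def emb_dim(n_categories):\n",
    "    # Same rule as in nn_ord: ceil(n / 2), at least 2, at most 50.\n",
    "    return int(max(min(np.ceil(.5 * n_categories), 50), 2))\n",
    "\n",
    "# Category counts as printed for cat_vars_ord above.\n",
    "cat_counts = {0: 9, 1: 7, 2: 4, 3: 9, 4: 6, 5: 5, 6: 2, 7: 11}\n",
    "dims = {col: emb_dim(n) for col, n in cat_counts.items()}\n",
    "print(dims)  # {0: 5, 1: 4, 2: 2, 3: 5, 4: 3, 5: 3, 6: 2, 7: 6}\n",
    "\n",
    "# Width of the concatenated layer: all embeddings plus the 16-unit numerical branch.\n",
    "print(sum(dims.values()) + 16)  # 46\n",
    "```"
   ]
  },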
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            [(None, 12)]         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "lambda (Lambda)                 (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_2 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_3 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_4 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_5 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_6 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_7 (Lambda)               (None, 1)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           (None, 1, 5)         50          lambda[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "embedding_1 (Embedding)         (None, 1, 4)         32          lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_2 (Embedding)         (None, 1, 2)         10          lambda_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_3 (Embedding)         (None, 1, 5)         50          lambda_3[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_4 (Embedding)         (None, 1, 3)         21          lambda_4[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_5 (Embedding)         (None, 1, 3)         18          lambda_5[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_6 (Embedding)         (None, 1, 2)         6           lambda_6[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "embedding_7 (Embedding)         (None, 1, 6)         72          lambda_7[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_8 (Lambda)               (None, 4)            0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "reshape (Reshape)               (None, 5)            0           embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "reshape_1 (Reshape)             (None, 4)            0           embedding_1[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "reshape_2 (Reshape)             (None, 2)            0           embedding_2[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "reshape_3 (Reshape)             (None, 5)            0           embedding_3[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "reshape_4 (Reshape)             (None, 3)            0           embedding_4[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "reshape_5 (Reshape)             (None, 3)            0           embedding_5[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "reshape_6 (Reshape)             (None, 2)            0           embedding_6[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "reshape_7 (Reshape)             (None, 6)            0           embedding_7[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, 16)           80          lambda_8[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "concatenate (Concatenate)       (None, 46)           0           reshape[0][0]                    \n",
      "                                                                 reshape_1[0][0]                  \n",
      "                                                                 reshape_2[0][0]                  \n",
      "                                                                 reshape_3[0][0]                  \n",
      "                                                                 reshape_4[0][0]                  \n",
      "                                                                 reshape_5[0][0]                  \n",
      "                                                                 reshape_6[0][0]                  \n",
      "                                                                 reshape_7[0][0]                  \n",
      "                                                                 dense[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 60)           2820        concatenate[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "dropout (Dropout)               (None, 60)           0           dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 60)           3660        dropout[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 60)           0           dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 60)           3660        dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dropout_2 (Dropout)             (None, 60)           0           dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 2)            122         dropout_2[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 10,601\n",
      "Trainable params: 10,601\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f482905f8d0>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set_seed()\n",
    "nn = nn_ord()\n",
    "nn.summary()\n",
    "nn.fit(X_train, to_categorical(y_train), batch_size=128, epochs=30, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate counterfactual"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Original instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = X_test[0].reshape((1,) + X_test[0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize counterfactual parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "shape = X.shape\n",
    "beta = .01\n",
    "c_init = 1.\n",
    "c_steps = 5\n",
    "max_iterations = 500\n",
    "rng = (-1., 1.)  # scale features between -1 and 1\n",
    "rng_shape = (1,) + data.shape[1:]\n",
    "feature_range = ((np.ones(rng_shape) * rng[0]).astype(np.float32), \n",
    "                 (np.ones(rng_shape) * rng[1]).astype(np.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the explainer. Because the `Embedding` layers in `tf.keras` do not let gradients propagate through to the inputs, we only use the model's predict function, treat the model as a black box, and compute the gradients numerically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed()\n",
    "\n",
    "# define predict function\n",
    "predict_fn = lambda x: nn.predict(x)\n",
    "\n",
    "cf = CounterfactualProto(predict_fn,\n",
    "                         shape,\n",
    "                         beta=beta,\n",
    "                         cat_vars=cat_vars_ord,\n",
    "                         max_iterations=max_iterations,\n",
    "                         feature_range=feature_range,\n",
    "                         c_init=c_init,\n",
    "                         c_steps=c_steps,\n",
    "                         eps=(.01, .01)  # perturbation size for numerical gradients\n",
    "                        )"
   ]
  },
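  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what a numerical gradient of a black-box predict function looks like, here is a generic central-difference sketch. It is not alibi's internal implementation; `numerical_gradient`, `toy_predict` and the weight matrix `W` are hypothetical names introduced for illustration, and `eps` plays the role of a perturbation size.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def numerical_gradient(predict_fn, x, eps=1e-2):\n",
    "    # Central-difference estimate of d predict_fn / d x for one instance x of shape (1, n).\n",
    "    n_features = x.shape[1]\n",
    "    n_outputs = predict_fn(x).shape[1]\n",
    "    grad = np.zeros((n_outputs, n_features))\n",
    "    for j in range(n_features):\n",
    "        step = np.zeros_like(x)\n",
    "        step[0, j] = eps  # perturb one feature at a time\n",
    "        grad[:, j] = (predict_fn(x + step) - predict_fn(x - step))[0] / (2 * eps)\n",
    "    return grad\n",
    "\n",
    "# Toy black box (assumption): softmax over a fixed linear map.\n",
    "W = np.array([[1., -2.], [0.5, 0.]])\n",
    "\n",
    "def toy_predict(x):\n",
    "    logits = x @ W\n",
    "    e = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
    "    return e / e.sum(axis=1, keepdims=True)\n",
    "\n",
    "x0 = np.array([[0.2, -0.1]])\n",
    "g = numerical_gradient(toy_predict, x0)\n",
    "print(g.shape)  # (2, 2): one row per class, one column per feature\n",
    "```\n",
    "\n",
    "Because the two class probabilities always sum to one, the two rows of `g` cancel column by column. Each gradient estimate costs two calls to the predict function per feature, which is one reason black-box explanations are slower than those that use analytical gradients."
   ]
  },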
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fit explainer. Please check the [documentation](https://docs.seldon.io/projects/alibi/en/stable/methods/CFProto.html) for more info about the optional arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "cf.fit(X_train, d_type='abdm', disc_perc=[25, 50, 75]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Explain instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed()\n",
    "explanation = cf.explain(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper function to more clearly describe explanations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def describe_instance(X, explanation, eps=1e-2):\n",
    "    print('Original instance: {}  -- proba: {}'.format(target_names[explanation.orig_class],\n",
    "                                                       explanation.orig_proba[0]))\n",
    "    print('Counterfactual instance: {}  -- proba: {}'.format(target_names[explanation.cf['class']],\n",
    "                                                             explanation.cf['proba'][0]))\n",
    "    print('\\nCounterfactual perturbations...')\n",
    "    print('\\nCategorical:')\n",
    "    X_orig_ord = X\n",
    "    X_cf_ord = explanation.cf['X']\n",
    "    delta_cat = {}\n",
    "    for i, (_, v) in enumerate(category_map.items()):\n",
    "        cat_orig = v[int(X_orig_ord[0, i])]\n",
    "        cat_cf = v[int(X_cf_ord[0, i])]\n",
    "        if cat_orig != cat_cf:\n",
    "            delta_cat[feature_names[i]] = [cat_orig, cat_cf]\n",
    "    if delta_cat:\n",
    "        for k, v in delta_cat.items():\n",
    "            print('{}: {}  -->   {}'.format(k, v[0], v[1]))\n",
    "    print('\\nNumerical:')\n",
    "    delta_num = X_cf_ord[0, -4:] - X_orig_ord[0, -4:]\n",
    "    n_keys = len(list(cat_vars_ord.keys()))\n",
    "    for i in range(delta_num.shape[0]):\n",
    "        if np.abs(delta_num[i]) > eps:\n",
    "            print('{}: {:.2f}  -->   {:.2f}'.format(feature_names[i+n_keys],\n",
    "                                            X_orig_ord[0,i+n_keys],\n",
    "                                            X_cf_ord[0,i+n_keys]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original instance: <=50K  -- proba: [0.6976237  0.30237624]\n",
      "Counterfactual instance: >50K  -- proba: [0.49604183 0.5039582 ]\n",
      "\n",
      "Counterfactual perturbations...\n",
      "\n",
      "Categorical:\n",
      "\n",
      "Numerical:\n",
      "Capital Gain: -1.00  -->   -0.88\n"
     ]
    }
   ],
   "source": [
    "describe_instance(X, explanation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The person's income is predicted to be above $50k once their capital gain is increased."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
